import sys
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def get_token2text(token, args, id2label):
    """
    Convert a batch of token-id sequences to text.

    Every sequence in ``token`` is decoded, but only the text of the LAST
    sequence is returned (each iteration overwrites the previous result).
    Decoding of a sequence stops at the first ``args.PAD_TOKEN``.

    args:
        token: iterable of sequences of token ids (each id must be
            convertible with ``int()``)
        args: object providing ``PAD_TOKEN``, ``SOS_CHAR`` and ``EOS_CHAR``
        id2label: mapping from integer token id to its string label
    returns:
        decoded text of the last sequence with every occurrence of
        ``args.SOS_CHAR`` and ``args.EOS_CHAR`` removed; ``""`` when
        ``token`` contains no sequences
    """
    # Bound up front so an empty batch yields "" instead of an
    # UnboundLocalError after the loop.
    text = ""
    try:  # handle case for CTC
        for sequence in token:
            labels = []
            for token_id in sequence:
                token_id = int(token_id)
                if token_id == args.PAD_TOKEN:
                    break
                labels.append(id2label[token_id])
            text = "".join(labels)
    except Exception as e:
        # NOTE(review): any decoding failure terminates the process with
        # exit status 0; kept as-is because callers may rely on it.
        print(e)
        sys.exit(0)

    return text.replace(args.SOS_CHAR, '').replace(args.EOS_CHAR, '')

def calculate_metrics(pred, gold, args, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Calculate the loss and, for cross-entropy, the number of correct tokens.

    args:
        pred: B x T x C
        gold: B x T
        args: object providing ``PAD_TOKEN``
        input_lengths: B (for CTC), forwarded to ``calculate_loss``
        target_lengths: B (for CTC), forwarded to ``calculate_loss``
        smoothing: label-smoothing value, forwarded to ``calculate_loss``
        loss_type: "ce" or "ctc"
    returns:
        ``(loss, num_correct)`` for "ce", where ``num_correct`` is the number
        of non-pad positions whose highest-scoring class equals ``gold``;
        ``(loss, None)`` for "ctc";
        ``(None, None)`` for any other ``loss_type``
    """
    # Reject unknown loss types before computing anything, so the documented
    # (None, None) fallback is reached no matter how calculate_loss treats
    # an unknown loss_type.
    if loss_type not in ("ce", "ctc"):
        print("loss is not defined")
        return None, None

    loss = calculate_loss(pred, gold, args, input_lengths, target_lengths, smoothing, loss_type)
    if loss_type == "ctc":
        # No token-level accuracy is computed for CTC.
        return loss, None

    pred = pred.view(-1, pred.size(2))  # (B*T) x C
    gold = gold.contiguous().view(-1)  # (B*T)
    predicted_ids = pred.max(1)[1]  # index of the highest-scoring class
    non_pad_mask = gold.ne(args.PAD_TOKEN)
    # Count matches on non-pad positions only.
    num_correct = predicted_ids.eq(gold).masked_select(non_pad_mask).sum().item()
    return loss, num_correct


def calculate_loss(pred, gold, args, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Calculate loss

    Only cross-entropy ("ce") is implemented in this function. For any other
    ``loss_type`` the message "loss is not defined" is printed and ``None``
    is returned.

    args:
        pred: B x T x C (treated as logits; softmax is applied here)
        gold: B x T
        args: object providing ``PAD_TOKEN``
        input_lengths: B (for CTC); not used by this function
        target_lengths: B (for CTC); not used by this function
        smoothing: label-smoothing epsilon; values > 0.0 enable smoothing
        loss_type: "ce"
    returns:
        scalar loss tensor averaged over non-pad positions, or ``None`` when
        ``loss_type`` is not supported
    """
    if loss_type != "ce":
        # Return explicitly instead of falling through to an unbound `loss`.
        print("loss is not defined")
        return None

    pred = pred.view(-1, pred.size(2))  # (B*T) x C
    gold = gold.contiguous().view(-1)  # (B*T)

    if smoothing > 0.0:
        eps = smoothing
        num_class = pred.size(1)

        # Pad positions are mapped to class index 0 for the scatter below;
        # their rows are removed again by non_pad_mask before averaging.
        gold_for_scatter = gold.ne(args.PAD_TOKEN).long() * gold
        one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
        # Target class gets (1 - eps); every other class gets eps / num_class.
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / num_class
        log_prob = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(args.PAD_TOKEN)
        num_word = non_pad_mask.sum().item()
        loss = -(one_hot * log_prob).sum(dim=1)
        return loss.masked_select(non_pad_mask).sum() / num_word

    # Softmax is invariant to adding a constant to every logit, so the -20
    # offset does not change the loss mathematically.
    # NOTE(review): purpose of the offset is not evident from this file; kept.
    return F.cross_entropy(pred - 20, gold, ignore_index=args.PAD_TOKEN, reduction="mean")



def _ids_to_text(ids, pad_token, id2label):
    """Map one sequence of token ids to text, stopping at the first pad token."""
    labels = []
    for token_id in ids:
        token_id = int(token_id)
        if token_id == pad_token:
            break
        labels.append(id2label[token_id])
    return "".join(labels)


def get_loss(model, args, benign, benign_lengths, benign_percentages, adv_tgt, adv_tgt_lengths, id2label):
    """
    Run the model on one batch and return its loss plus decoded transcripts.

    args:
        model: callable invoked as
            ``model(benign, benign_lengths, adv_tgt, verbose=False)`` and
            returning ``(pred, gold, hyp_seq, gold_seq)``
        args: object providing ``PAD_TOKEN``, ``SOS_CHAR``, ``EOS_CHAR``,
            ``label_smoothing`` and ``loss``
        benign: model input batch
        benign_lengths: forwarded to the model
        benign_percentages: tensor scaled by ``pred.size(1)`` to obtain the
            ``input_lengths`` passed on for CTC; it is NOT modified
        adv_tgt: target batch forwarded to the model
        adv_tgt_lengths: passed on as ``target_lengths``
        id2label: mapping from integer token id to its string label
    returns:
        ``(loss, num_correct, hyp_text, gold_text)`` where the two texts
        belong to the first sequence of the batch with SOS/EOS characters
        removed. Non-finite loss values (Inf/NaN) are replaced by zero;
        ``loss`` is ``None`` if ``calculate_metrics`` returned no loss.
    """
    pred, gold, hyp_seq, gold_seq = model(benign, benign_lengths, adv_tgt, verbose=False)
    hyp_seq = hyp_seq.cpu()
    gold_seq = gold_seq.cpu()

    try:  # handle case for CTC
        strs_gold = [_ids_to_text(ut_gold, args.PAD_TOKEN, id2label) for ut_gold in gold_seq]
        strs_hyps = [_ids_to_text(ut_hyp, args.PAD_TOKEN, id2label) for ut_hyp in hyp_seq]
    except Exception as e:
        # NOTE(review): any decoding failure terminates the process with
        # exit status 0; kept as-is because callers may rely on it.
        print(e)
        sys.exit(0)

    seq_length = pred.size(1)
    # Out-of-place mul: the in-place mul_ rescaled the caller's tensor, so
    # the factor compounded every time the same tensor was passed in again.
    sizes = benign_percentages.mul(int(seq_length)).int()

    loss, num_correct = calculate_metrics(
        pred, gold, args, input_lengths=sizes, target_lengths=adv_tgt_lengths,
        smoothing=args.label_smoothing, loss_type=args.loss)

    hyp_text = strs_hyps[0].replace(args.SOS_CHAR, '').replace(args.EOS_CHAR, '')
    gold_text = strs_gold[0].replace(args.SOS_CHAR, '').replace(args.EOS_CHAR, '')

    # Mask both Inf and NaN. The previous check fired only on Inf while its
    # mask matched only NaN, so it never changed the loss.
    if loss is not None and not torch.isfinite(loss).all():
        loss = torch.where(torch.isfinite(loss), loss, torch.zeros_like(loss))

    return loss, num_correct, hyp_text, gold_text

def calculate_SNR(perturbations, y):
    """
    Return the signal-to-noise ratio of ``y`` against ``perturbations`` in dB.

    Computed as ``10 * log10(sum(y**2) / (sum(perturbations**2) + 1e-8))``;
    the small constant keeps the denominator non-zero when the perturbation
    is all zeros.
    """
    noise_energy = torch.sum(torch.pow(perturbations, 2))
    signal_energy = torch.sum(torch.pow(y, 2))
    ratio = signal_energy / (noise_energy + 1e-8)
    return 10 * torch.log10(ratio)